{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bulk Inference\n",
    "\n",
    "This notebook demonstrates how to run bulk inference against various LLM APIss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prerequisites\n",
    "\n",
    "## Machine Requirements\n",
    "\n",
    "Since we're calling APIs remotely, no GPU or special hardware is needed. You may need internet access when calling an API that isn't accessible via your local network.\n",
    "\n",
    "## Oumi Installation\n",
    "\n",
    "First, let's install Oumi. You can find more detailed instructions about Oumi installation [here](https://oumi.ai/docs/en/latest/get_started/installation.html).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install oumi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## API Keys\n",
    "\n",
    "Different model providers each have their own API keys which must be set when calling each one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from oumi.core.configs import InferenceConfig\n",
    "from oumi.core.types import Conversation, Message, Role\n",
    "\n",
    "# Set up API keys for different providers\n",
    "if (\n",
    "    not os.environ.get(\"ANTHROPIC_API_KEY\")\n",
    "    or os.environ.get(\"ANTHROPIC_API_KEY\") == \"<MY_ANTHROPIC_API_KEY>\"\n",
    "):\n",
    "    os.environ[\"ANTHROPIC_API_KEY\"] = \"<MY_ANTHROPIC_API_KEY>\"\n",
    "\n",
    "if (\n",
    "    not os.environ.get(\"OPENAI_API_KEY\")\n",
    "    or os.environ.get(\"OPENAI_API_KEY\") == \"<MY_OPENAI_API_KEY>\"\n",
    "):\n",
    "    os.environ[\"OPENAI_API_KEY\"] = \"<MY_OPENAI_API_KEY>\"\n",
    "\n",
    "if (\n",
    "    not os.environ.get(\"GEMINI_API_KEY\")\n",
    "    or os.environ.get(\"GEMINI_API_KEY\") == \"<MY_GEMINI_API_KEY>\"\n",
    "):\n",
    "    os.environ[\"GEMINI_API_KEY\"] = \"<MY_GEMINI_API_KEY>\"\n",
    "\n",
    "# Display which API keys are configured\n",
    "anthropic_key = os.environ.get(\"ANTHROPIC_API_KEY\")\n",
    "openai_key = os.environ.get(\"OPENAI_API_KEY\")\n",
    "google_key = os.environ.get(\"GEMINI_API_KEY\")\n",
    "\n",
    "print(f\"Using Anthropic API Key: '{anthropic_key}'\")\n",
    "print(f\"Using OpenAI API Key: '{openai_key}'\")\n",
    "print(f\"Using Google API Key: '{google_key}'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up the config file\n",
    "\n",
    "Note: in this section we are writing the config file to the current working directory.\n",
    "\n",
    "An alternative option is to initialize the params classes directly: `ModelParams`, `GenerationParams`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_path = \"api_tutorial_inference_config.yaml\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile api_tutorial_inference_config.yaml\n",
    "\n",
    "model:\n",
    "  model_name: \"claude-sonnet-4-20250514\"\n",
    "  # model_name: \"gpt-4o-mini\"\n",
    "  # model_name: \"gemini-2.0-flash\"\n",
    "\n",
    "generation:\n",
    "  max_new_tokens: 512\n",
    "\n",
    "remote_params:\n",
    "  num_workers: 5 # max number of workers to run in parallel\n",
    "  politeness_policy: 60 # wait 60 seconds before sending next request\n",
    "\n",
    "engine: \"ANTHROPIC\"\n",
    "# engine: \"OPENAI\"\n",
    "# engine: \"GOOGLE_GEMINI\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the model and the inference engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.builders.inference_engines import build_inference_engine\n",
    "\n",
    "config = InferenceConfig.from_yaml(config_path)\n",
    "model_params = config.model\n",
    "remote_params = config.remote_params\n",
    "engine_type = config.engine\n",
    "if not engine_type:\n",
    "    print(\"Check your config file to ensure you have an engine type specified.\")\n",
    "    exit()\n",
    "\n",
    "inference_engine = build_inference_engine(engine_type, model_params, remote_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing our inputs\n",
    "\n",
    "The inference engine expects a list of conversations, where each conversation is a list of messages.\n",
    "\n",
    "See the [Conversation](https://github.com/oumi-ai/oumi/blob/38b3d2b27407be5fc9be5a1dd88f9ad518f3491c/src/oumi/core/types/turn.py#L109) class for more details.\n",
    "\n",
    "Tip: you can visualize how the conversation is rendered as a prompt with the following:\n",
    "\n",
    "```python\n",
    "inference_engine.apply_chat_template(conversation, tokenize=False)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "ds = load_dataset(\"fka/awesome-chatgpt-prompts\", split=\"train\")\n",
    "\n",
    "prompts = [sample[\"prompt\"] for sample in ds]  # type: ignore\n",
    "\n",
    "conversations = [\n",
    "    Conversation(\n",
    "        messages=[\n",
    "            Message(role=Role.USER, content=prompt),\n",
    "        ]\n",
    "    )\n",
    "    for prompt in prompts\n",
    "]\n",
    "\n",
    "print(len(conversations))\n",
    "print(conversations[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limit the number of conversations to 10 for testing\n",
    "\n",
    "inference_conversations = conversations[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Running inference for {len(inference_conversations)} conversations\")\n",
    "\n",
    "generations = inference_engine.infer(\n",
    "    input=inference_conversations,\n",
    "    inference_config=config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for conversation in generations:\n",
    "    print(repr(conversation))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference Features\n",
    "#### Parallel Inference\n",
    "Under the hood, oumi will parallelize your inference, up to a maximum of `num_workers` requests.\n",
    "\n",
    "For example, if `num_workers = 5`, oumi will process 5 requests at a time.\n",
    "\n",
    "Make sure to feed all your prompts to the engine at once for maximum throughput.\n",
    "\n",
    "#### Rate Limiting\n",
    "In addition to parallization, oumi offers support for rate limiting through `politeness_policy`.\n",
    "\n",
    "For example, if `politeness_policy = 60`, then each worker will wait 60s before submitting its next request.\n",
    "\n",
    "This is useful for matching the request-per-minute quota limits set by various API providers, ensuring that your inference job succeeds.\n",
    "\n",
    "#### Adaptive Throughput\n",
    "By default, oumi will adapt its parallelization (num_workers) based on error rate, up to the maximum set in `num_workers`.\n",
    "\n",
    "To start, oumi will run inference at 50% of `num_workers`, then scale up as requests are made successfully.\n",
    "\n",
    "One inference begins getting failed requests, inference enters a `backoff` state, where the number of active workers is reduced by some fraction.\n",
    "\n",
    "After entering `backoff`, if inference continues to experience a high error rate, oumi continues to `backoff`.\n",
    "\n",
    "After a period of stability and successful requests, oumi enters a `warmup` state and begins increasing the number of active workers up to the maximum `num_workers`.\n",
    "\n",
    "This option can be disabled by setting `use_adaptive_concurrency = False` in the `RemoteParams`.\n",
    "\n",
    "#### Progress Saving\n",
    "Whether you specify an output directory or not, oumi automatically saves results as it receives them to your local disk, ensuring that your inference job loses no progress in the event the API goes down.\n",
    "\n",
    "Additionally, if you rerun inference with the same dataset, model, and generation parameters, oumi will resume from where you left off!\n",
    "\n",
    "**Scratch file locations:**\n",
    "- **With output path specified**: Files are saved in `output_path/scratch/<filename>_<hash>.<extension>` where the hash is based on your model parameters, generation parameters, and dataset content\n",
    "- **Without output path**: Files are saved in `~/.cache/oumi/tmp/temp_inference_output_<hash>.jsonl`\n",
    "\n",
    "The hash-based naming ensures cache consistency and prevents conflicts between different inference configurations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🧭 What's Next?\n",
    "\n",
    "Congrats on finishing this notebook! Feel free to check out our other [notebooks](https://github.com/oumi-ai/oumi/tree/main/notebooks) in the [Oumi GitHub](https://github.com/oumi-ai/oumi), and give us a star! You can also join the Oumi community over on [Discord](https://discord.gg/oumi).\n",
    "\n",
    "📰 Want to keep up with news from Oumi? Subscribe to our [Substack](https://blog.oumi.ai/) and [Youtube](https://www.youtube.com/@Oumi_AI)!\n",
    "\n",
    "⚡ Interested in building custom AI in hours, not months? Apply to get [early access](https://oumi-ai.typeform.com/early-access) to the Oumi Platform, or [chat with us](https://calendly.com/d/ctcx-nps-47m/chat-with-us-get-early-access-to-the-oumi-platform) to learn more!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "oumi-demo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
